# dim_infer.py
import json
import glob
import os


def infer_dims_from_file(data_path_pattern):
    """Infer observation and action dimensions from the first matching JSON file.

    Args:
        data_path_pattern: Glob pattern selecting episode files, e.g.
            "offline_data/2s3z/good/*.json".

    Returns:
        A ``(global_dim, local_dim, action_dim)`` tuple read from the first
        step of the first episode of the first matching file (sorted order).

    Raises:
        FileNotFoundError: If the pattern matches no files.
        ValueError: If the file's contents do not have the expected nesting
            (for example an empty episode list). Invalid JSON surfaces as
            ``json.JSONDecodeError``, which is itself a ``ValueError``.
    """
    files = sorted(glob.glob(data_path_pattern))
    if not files:
        raise FileNotFoundError(f"No JSON files found in {data_path_pattern}")

    # Only the first file is inspected; presumably every file matched by the
    # pattern shares the same dimensions -- confirm against the data pipeline.
    first_file = files[0]

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(first_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # MADT episode format:
    #   [ [ [global_obs, local_obs, act, reward, done, avail_act], ... ] ]
    try:
        first_step = data[0][0]
        global_dim = len(first_step[0])  # global state
        local_dim = len(first_step[1])  # local observation
        # The available-actions entry has one slot per action, so its length
        # is used as the action dimension.
        action_dim = len(first_step[5])
    except (IndexError, KeyError, TypeError) as err:
        raise ValueError(
            f"Unexpected episode structure in {first_file}: expected "
            "[[[global_obs, local_obs, act, reward, done, avail_act], ...], ...]"
        ) from err

    return global_dim, local_dim, action_dim
